{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "omegafold_hacks.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/omegafold_hacks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#**OmegaFold_hacks**\n",
        "**WARNING** This is an experimental notebook that implements various hacks (unintended uses) of OmegaFold. It has not been optimized or benchmarked. Use at your own risk!\n",
        "\n",
        "For detail about the original model see: [Github](https://github.com/HeliXonProtein/OmegaFold), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.21.500999v1)\n",
        "\n",
        "\n",
        "#### **Tips and Instructions**\n",
        "- click the little ▶ play icon to the left of each cell below.\n",
        "\n",
        "#### **Experimental Options**\n",
        "- use \"/\" to specify chainbreaks, (eg. sequence=\"AAA/AAA\")\n",
        "- if confidence of prediction is low, try `use_real_msa` (currently only supported for monomeric and homo-oligomeric predictions)\n"
      ],
      "metadata": {
        "id": "qEgffIPOyEgk"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "oVkGfduwiAkw"
      },
      "outputs": [],
      "source": [
        "#@markdown ##Install\n",
        "import os,sys,re\n",
        "from IPython.utils import io\n",
        "if \"SETUP_DONE\" not in dir():\n",
        "  import torch\n",
        "  device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "  with io.capture_output() as captured:\n",
        "    if not os.path.isdir(\"OmegaFold\"):\n",
        "      %shell git clone --quiet https://github.com/sokrypton/OmegaFold.git\n",
        "      %shell cd OmegaFold; pip -q install -r requirements.txt\n",
        "      %shell pip -q install py3Dmol\n",
        "      %shell apt-get install aria2 -qq > /dev/null\n",
        "      %shell aria2c -q -x 16 https://helixon.s3.amazonaws.com/release1.pt\n",
        "      %shell mkdir -p ~/.cache/omegafold_ckpt\n",
        "      %shell mv release1.pt ~/.cache/omegafold_ckpt/model.pt\n",
        "      %shell wget -qnc https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz\n",
        "      %shell tar xfz hhsuite-3.3.0-SSE2-Linux.tar.gz\n",
        "      %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py\n",
        "  os.environ['PATH'] += \":/content/bin:/content/scripts\"\n",
        "  SETUP_DONE = True"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown ##Prep inputs\n",
        "from string import ascii_uppercase, ascii_lowercase\n",
        "import colabfold as cf\n",
        "alphabet_list = list(ascii_uppercase+ascii_lowercase)\n",
        "\n",
        "sequence = \"PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\" #@param {type:\"string\"}\n",
        "jobname = \"test\" #@param {type:\"string\"}\n",
        "\n",
        "# filter inputs\n",
        "sequence = re.sub(\"[^A-Z:/]\", \"\", sequence.upper())\n",
        "sequence = re.sub(\":+\",\":\",sequence)\n",
        "sequence = re.sub(\"/+\",\"/\",sequence)\n",
        "sequence = re.sub(\"^[:/]+\",\"\",sequence)\n",
        "sequence = re.sub(\"[:/]+$\",\"\",sequence)\n",
        "jobname = re.sub(r'\\W+', '', jobname)[:50]\n",
        "\n",
        "num_msa = 16\n",
        "use_real_msa = False #@param {type:\"boolean\"}\n",
        "\n",
        "sequence = re.sub(\"[^A-Z\\/\\:]\", \"\", sequence.upper()).replace(\"/\",\":\")\n",
        "ID = jobname+\"_\"+cf.get_hash(sequence)[:5]\n",
        "seqs = sequence.split(\":\")\n",
        "lengths = [len(s) for s in seqs]\n",
        "u_seqs = list(set(seqs))\n",
        "\n",
        "if len(seqs) == 1: mode = \"mono\"\n",
        "elif len(u_seqs) == 1: mode = \"homo\"\n",
        "else: mode = \"hetero\"\n",
        "\n",
        "if use_real_msa and mode == \"hetero\":\n",
        "  print(\"ERROR: the msa mode is currently not supported for hetero-multimers\")\n",
        "  use_real_msa = False\n",
        "\n",
        "if use_real_msa:\n",
        "  ID_ = cf.get_hash(seqs[0])[:5]\n",
        "  a3m_lines = cf.run_mmseqs2(seqs[0], ID_, filter=True)[0]\n",
        "  with open(f\"{ID}.a3m\",\"w\") as a3m:\n",
        "    a3m.write(a3m_lines)\n",
        "  \n",
        "  %shell hhfilter -v 0 -qid 30 -cov 75 -diff {num_msa} -i {ID}.a3m -o {ID}.fasta\n",
        "  \n",
        "  a3m_lines = open(f\"{ID}.fasta\",\"r\").readlines()\n",
        "  with open(f\"{ID}.fasta\",\"w\") as a3m:\n",
        "    n = 0\n",
        "    for a3m_line in a3m_lines:\n",
        "      if a3m_line.startswith(\">\"):\n",
        "        if n == 0:\n",
        "          a3m.write(f\">{ID}\\n\")\n",
        "        else:\n",
        "          a3m.write(a3m_line)\n",
        "      else:\n",
        "        a3m_line = \":\".join([a3m_line.rstrip()] * len(lengths))\n",
        "        a3m.write(f\"{a3m_line}\\n\")\n",
        "        n += 1\n",
        "      if n == num_msa: break\n",
        "    # add pseudo_msa\n",
        "    while n < num_msa:\n",
        "      a3m.write(f\">tmp\\n{sequence}\\n\")\n",
        "      n += 1\n",
        "else:\n",
        "  with open(f\"{ID}.fasta\",\"w\") as out:\n",
        "    out.write(f\">{ID}\\n{sequence}\\n\")"
      ],
      "metadata": {
        "cellView": "form",
        "id": "CFCwEAa2oZEN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%time \n",
        "#@markdown ## Run OmegaFold\n",
        "num_cycle = 4 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\"] {type:\"raw\"}\n",
        "\n",
        "if use_real_msa:\n",
        "  %shell python ./OmegaFold/main.py --device={device} --num_cycle={num_cycle} --real_msa=1 {ID}.fasta .\n",
        "else:\n",
        "  %shell python ./OmegaFold/main.py --device={device} --num_cycle={num_cycle} --num_pseudo_msa={num_msa-1} {ID}.fasta .\n",
        "\n",
        "def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1):\n",
        "  if Ls is not None:\n",
        "    L_init = 0\n",
        "    new_chain = {}\n",
        "    for L,c in zip(Ls, alphabet_list):\n",
        "      new_chain.update({i:c for i in range(L_init,L_init+L)})\n",
        "      L_init += L  \n",
        "\n",
        "  n,num,pdb_out = 0,offset,[]\n",
        "  resnum_ = None\n",
        "  chain_ = None\n",
        "  new_chain_ = new_chain[0]\n",
        "  for line in pdb_str.split(\"\\n\"):\n",
        "    if line[:4] == \"ATOM\":\n",
        "      chain = line[21:22]\n",
        "      resnum = int(line[22:22+5])\n",
        "      if resnum_ is None: resnum_ = resnum\n",
        "      if chain_ is None: chain_ = chain\n",
        "      if resnum != resnum_ or chain != chain_:\n",
        "        num += (resnum - resnum_)  \n",
        "        n += 1\n",
        "        resnum_,chain_ = resnum,chain\n",
        "      if Ls is not None:\n",
        "        if new_chain[n] != new_chain_:\n",
        "          num = offset\n",
        "          new_chain_ = new_chain[n]\n",
        "      N = num if renum else resnum\n",
        "      if Ls is None: pdb_out.append(\"%s%4i%s\" % (line[:22],N,line[26:]))\n",
        "      else: pdb_out.append(\"%s%s%4i%s\" % (line[:21],new_chain[n],N,line[26:]))        \n",
        "  return \"\\n\".join(pdb_out)\n",
        "\n",
        "pdb_str = renum_pdb_str(open(f\"{ID}.pdb\",'r').read(), Ls=lengths)\n",
        "with open(f\"{ID}.pdb\",\"w\") as out:\n",
        "  out.write(pdb_str)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "_tUY92JL28DS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown ##Display\n",
        "import py3Dmol\n",
        "\n",
        "\n",
        "pymol_color_list = [\"#33ff33\",\"#00ffff\",\"#ff33cc\",\"#ffff00\",\"#ff9999\",\"#e5e5e5\",\"#7f7fff\",\"#ff7f00\",\n",
        "                    \"#7fff7f\",\"#199999\",\"#ff007f\",\"#ffdd5e\",\"#8c3f99\",\"#b2b2b2\",\"#007fff\",\"#c4b200\",\n",
        "                    \"#8cb266\",\"#00bfbf\",\"#b27f7f\",\"#fcd1a5\",\"#ff7f7f\",\"#ffbfdd\",\"#7fffff\",\"#ffff7f\",\n",
        "                    \"#00ff7f\",\"#337fcc\",\"#d8337f\",\"#bfff3f\",\"#ff7fff\",\"#d8d8ff\",\"#3fffbf\",\"#b78c4c\",\n",
        "                    \"#339933\",\"#66b2b2\",\"#ba8c84\",\"#84bf00\",\"#b24c66\",\"#7f7f7f\",\"#3f3fa5\",\"#a5512b\"]\n",
        "\n",
        "def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,\n",
        "             color=\"pLDDT\", chains=None, vmin=50, vmax=90,\n",
        "             size=(800,480), hbondCutoff=4.0,\n",
        "             Ls=None,\n",
        "             animate=False):\n",
        "  \n",
        "  if chains is None:\n",
        "    chains = 1 if Ls is None else len(Ls)\n",
        "  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n",
        "  if animate:\n",
        "    view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})\n",
        "  else:\n",
        "    view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})\n",
        "  if color == \"pLDDT\":\n",
        "    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})\n",
        "  elif color == \"rainbow\":\n",
        "    view.setStyle({'cartoon': {'color':'spectrum'}})\n",
        "  elif color == \"chain\":\n",
        "    for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list):\n",
        "       view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
        "  if show_sidechains:\n",
        "    BB = ['C','O','N']\n",
        "    HP = [\"ALA\",\"GLY\",\"VAL\",\"ILE\",\"LEU\",\"PHE\",\"MET\",\"PRO\",\"TRP\",\"CYS\",\"TYR\"]\n",
        "    view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
        "                  {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
        "    view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
        "                  {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
        "    view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
        "                  {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})  \n",
        "  if show_mainchains:\n",
        "    BB = ['C','O','N','CA']\n",
        "    view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
        "  view.zoomTo()\n",
        "  if animate: view.animate()\n",
        "  return view\n",
        "\n",
        "color = \"confidence\" #@param [\"confidence\", \"rainbow\", \"chain\"]\n",
        "if color == \"confidence\": color = \"pLDDT\"\n",
        "show_sidechains = False #@param {type:\"boolean\"}\n",
        "show_mainchains = False #@param {type:\"boolean\"}\n",
        "show_pdb(pdb_str, color=color, show_sidechains=show_sidechains, show_mainchains=show_mainchains,\n",
        "         Ls=lengths).show()"
      ],
      "metadata": {
        "cellView": "form",
        "id": "e6rpvs0Nqix5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Download prediction\n",
        "from google.colab import files\n",
        "files.download(f'{ID}.pdb')"
      ],
      "metadata": {
        "cellView": "form",
        "id": "12rxVAHSrmYQ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}